{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# Building a Tool-Calling Calculator Agent with Koog\n",
    "\n",
    "In this mini-tutorial we’ll build a calculator agent powered by **Koog** tool-calling.\n",
    "You’ll learn how to:\n",
    "- Design small, pure **tools** for arithmetic\n",
    "- Orchestrate **parallel** tool calls with Koog’s multiple-call strategy\n",
    "- Add lightweight **event logging** for transparency\n",
    "- Run with OpenAI (and optionally Ollama)\n",
    "\n",
    "We’ll keep the API tidy and idiomatic Kotlin, returning predictable results and handling edge cases (like division by zero) gracefully."
   ]
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Setup\n",
    "\n",
    "We assume you’re in a Kotlin Notebook environment with Koog available.\n",
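    "Next, create an LLM executor. The cell below uses OpenAI and reads the API key from the `OPENAI_API_KEY` environment variable."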
   ]
  },
  {
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-08-18T19:59:09.532848Z",
     "start_time": "2025-08-18T19:59:08.644052Z"
    }
   },
   "cell_type": "code",
   "source": [
    "%useLatestDescriptors\n",
    "%use koog\n",
    "\n",
    "\n",
    "val OPENAI_API_KEY = System.getenv(\"OPENAI_API_KEY\")\n",
    "    ?: error(\"Please set the OPENAI_API_KEY environment variable\")\n",
    "\n",
    "val executor = simpleOpenAIExecutor(OPENAI_API_KEY)"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Calculator Tools\n",
    "\n",
    "Tools are small, pure functions with clear contracts.\n",
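    "We use `Double` so that fractional inputs and results (for example, from division) are supported. Every tool returns its result as text, formatted the same way: whole numbers without a fractional part, everything else in `%.10g` notation (10 significant digits)."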
   ]
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-18T19:59:09.690170Z",
     "start_time": "2025-08-18T19:59:09.537329Z"
    }
   },
   "cell_type": "code",
   "source": [
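    "import ai.koog.agents.core.tools.annotations.Tool\n",
    "import java.util.Locale\n",
    "import kotlin.math.abs\n",
    "\n",
    "// Format helper: whole numbers render without a fractional part; other values use 10 significant digits.\n",
    "// The magnitude guard avoids Long overflow, and Locale.ROOT keeps '.' as the decimal separator.\n",
    "private fun Double.pretty(): String =\n",
    "    if (abs(this) < 1e15 && abs(this % 1.0) < 1e-9) this.toLong().toString()\n",
    "    else \"%.10g\".format(Locale.ROOT, this)\n",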
    "\n",
    "@LLMDescription(\"Tools for basic calculator operations\")\n",
    "class CalculatorTools : ToolSet {\n",
    "\n",
    "    @Tool\n",
    "    @LLMDescription(\"Adds two numbers and returns the sum as text.\")\n",
    "    fun plus(\n",
    "        @LLMDescription(\"First addend.\") a: Double,\n",
    "        @LLMDescription(\"Second addend.\") b: Double\n",
    "    ): String = (a + b).pretty()\n",
    "\n",
    "    @Tool\n",
    "    @LLMDescription(\"Subtracts the second number from the first and returns the difference as text.\")\n",
    "    fun minus(\n",
    "        @LLMDescription(\"Minuend.\") a: Double,\n",
    "        @LLMDescription(\"Subtrahend.\") b: Double\n",
    "    ): String = (a - b).pretty()\n",
    "\n",
    "    @Tool\n",
    "    @LLMDescription(\"Multiplies two numbers and returns the product as text.\")\n",
    "    fun multiply(\n",
    "        @LLMDescription(\"First factor.\") a: Double,\n",
    "        @LLMDescription(\"Second factor.\") b: Double\n",
    "    ): String = (a * b).pretty()\n",
    "\n",
    "    @Tool\n",
    "    @LLMDescription(\"Divides the first number by the second and returns the quotient as text. Returns an error message on division by zero.\")\n",
    "    fun divide(\n",
    "        @LLMDescription(\"Dividend.\") a: Double,\n",
    "        @LLMDescription(\"Divisor (must not be zero).\") b: Double\n",
    "    ): String = if (b == 0.0) { // exact zero check (also matches -0.0); tiny nonzero divisors are allowed\n",
    "        \"ERROR: Division by zero\"\n",
    "    } else {\n",
    "        (a / b).pretty()\n",
    "    }\n",
    "}"
   ],
   "outputs": [],
   "execution_count": 2
  },
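  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Because the tools are plain functions, their edge cases can be checked without calling an LLM. The sketch below is a standalone, illustrative copy of the formatting and division logic (the top-level `pretty` and `divide` here are stand-ins, not Koog APIs); it shows the text the model would get back for a few tricky inputs.\n",
    "\n",
    "```kotlin\n",
    "import java.util.Locale\n",
    "import kotlin.math.abs\n",
    "\n",
    "// Standalone copy of the formatting rule: whole numbers print without a fractional part.\n",
    "fun Double.pretty(): String =\n",
    "    if (abs(this) < 1e15 && abs(this % 1.0) < 1e-9) toLong().toString()\n",
    "    else \"%.10g\".format(Locale.ROOT, this)\n",
    "\n",
    "// Illustrative stand-in for CalculatorTools.divide: errors are returned as text, not thrown.\n",
    "fun divide(a: Double, b: Double): String =\n",
    "    if (b == 0.0) \"ERROR: Division by zero\" else (a / b).pretty()\n",
    "\n",
    "println(divide(300.0, -9.0)) // non-integer quotient, printed with 10 significant digits\n",
    "println(divide(10.0, 2.0))   // whole-number quotient\n",
    "println(divide(1.0, 0.0))    // division by zero\n",
    "println(divide(1.0, -0.0))   // -0.0 == 0.0 for Double, so this is rejected too\n",
    "println((1e20).pretty())     // too large for Long, so it stays in %g notation\n",
    "```\n",
    "\n",
    "Returning an error string instead of throwing lets the model read the problem and explain it to the user, which is the design choice the `divide` tool makes."
   ]
  },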
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Tool Registry\n",
    "\n",
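    "Register the calculator tools, plus two built-in tools for interacting with the user: `AskUser` (ask a clarifying question) and `SayToUser` (show a message to the user)."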
   ]
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-18T19:59:10.637863Z",
     "start_time": "2025-08-18T19:59:10.572195Z"
    }
   },
   "cell_type": "code",
   "source": [
    "val toolRegistry = ToolRegistry {\n",
    "    tool(AskUser)   // enables explicit user clarification when needed\n",
    "    tool(SayToUser) // allows the agent to present the final message to the user\n",
    "    tools(CalculatorTools())\n",
    "}"
   ],
   "outputs": [],
   "execution_count": 3
  },
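  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Conceptually, a tool registry maps each tool name to code that can run it: when the LLM replies with a tool call, the agent looks the name up and executes the matching tool with the decoded arguments. The sketch below is a deliberately simplified, hypothetical model of that lookup in plain Kotlin (`ToolCall`, `registry`, and `dispatch` are illustrative names; this is not how Koog implements `ToolRegistry`).\n",
    "\n",
    "```kotlin\n",
    "// Hypothetical tool call as an LLM might request it: a tool name plus named arguments.\n",
    "data class ToolCall(val name: String, val args: Map<String, Double>)\n",
    "\n",
    "// Hypothetical registry: tool name -> function that takes the arguments and returns text.\n",
    "val registry: Map<String, (Map<String, Double>) -> String> = mapOf(\n",
    "    \"plus\" to { args: Map<String, Double> -> (args.getValue(\"a\") + args.getValue(\"b\")).toString() },\n",
    "    \"minus\" to { args: Map<String, Double> -> (args.getValue(\"a\") - args.getValue(\"b\")).toString() }\n",
    ")\n",
    "\n",
    "// Look up the requested tool; an unknown name becomes an error message instead of a crash.\n",
    "fun dispatch(call: ToolCall): String =\n",
    "    registry[call.name]?.invoke(call.args) ?: \"ERROR: Unknown tool '${call.name}'\"\n",
    "\n",
    "val calls = listOf(\n",
    "    ToolCall(\"plus\", mapOf(\"a\" to 10.0, \"b\" to 20.0)),\n",
    "    ToolCall(\"minus\", mapOf(\"a\" to 2.0, \"b\" to 11.0)),\n",
    "    ToolCall(\"power\", mapOf(\"a\" to 2.0, \"b\" to 3.0))\n",
    ")\n",
    "calls.forEach { println(\"${it.name} -> ${dispatch(it)}\") }\n",
    "```\n",
    "\n",
    "The `power` call has no registered tool, so it comes back as an error message that the model can react to rather than an exception."
   ]
  },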
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Strategy: Multiple Tool Calls (with Optional Compression)\n",
    "\n",
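    "This strategy lets the LLM request **multiple tool calls in a single response** (for example, two `plus` calls and a `minus`), runs them in parallel (`parallelTools = true`), and sends all of the results back so the LLM can request more tools or give the final answer.\n",
    "If the prompt’s latest token usage exceeds a threshold (1000 tokens here), we first **compress** the conversation history with `nodeLLMCompressHistory` and then send the tool results."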
   ]
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-18T19:59:11.885365Z",
     "start_time": "2025-08-18T19:59:11.529431Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import ai.koog.agents.core.environment.ReceivedToolResult\n",
    "\n",
    "object CalculatorStrategy {\n",
    "    private const val MAX_TOKENS_THRESHOLD = 1000\n",
    "\n",
    "    val strategy = strategy<String, String>(\"calculator\") {\n",
    "        val callLLM by nodeLLMRequestMultiple()\n",
    "        val executeTools by nodeExecuteMultipleTools(parallelTools = true)\n",
    "        val sendToolResults by nodeLLMSendMultipleToolResults()\n",
    "        val compressHistory by nodeLLMCompressHistory<List<ReceivedToolResult>>()\n",
    "\n",
    "        edge(nodeStart forwardTo callLLM)\n",
    "\n",
    "        // If the assistant produced a final answer, finish.\n",
    "        edge((callLLM forwardTo nodeFinish) transformed { it.first() } onAssistantMessage { true })\n",
    "\n",
    "        // Otherwise, run the tools the LLM requested (possibly several in parallel).\n",
    "        edge((callLLM forwardTo executeTools) onMultipleToolCalls { true })\n",
    "\n",
    "        // If the prompt is getting large, compress the history before sending the tool results.\n",
    "        edge(\n",
    "            (executeTools forwardTo compressHistory)\n",
    "                onCondition { llm.readSession { prompt.latestTokenUsage > MAX_TOKENS_THRESHOLD } }\n",
    "        )\n",
    "        edge(compressHistory forwardTo sendToolResults)\n",
    "\n",
    "        // Normal path: send tool results back to the LLM.\n",
    "        edge(\n",
    "            (executeTools forwardTo sendToolResults)\n",
    "                onCondition { llm.readSession { prompt.latestTokenUsage <= MAX_TOKENS_THRESHOLD } }\n",
    "        )\n",
    "\n",
    "        // LLM might request more tools after seeing results.\n",
    "        edge((sendToolResults forwardTo executeTools) onMultipleToolCalls { true })\n",
    "\n",
    "        // Or it can produce the final answer.\n",
    "        edge((sendToolResults forwardTo nodeFinish) transformed { it.first() } onAssistantMessage { true })\n",
    "    }\n",
    "}"
   ],
   "outputs": [],
   "execution_count": 4
  },
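  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "To make the edge conditions concrete, here is a plain-Kotlin simulation of the same graph that replays a scripted sequence of LLM replies instead of calling a model. Everything in it (`Reply`, `simulate`, and the token counts) is hypothetical; it only mirrors the routing rules above: tool calls go to `executeTools`, results pass through `compressHistory` only when token usage exceeds the threshold, and a plain assistant message ends the run.\n",
    "\n",
    "```kotlin\n",
    "// Hypothetical scripted LLM reply: a batch of tool calls, or a final answer when `text` is set.\n",
    "data class Reply(val toolCalls: List<String> = emptyList(), val text: String? = null)\n",
    "\n",
    "// Walks the same edges as CalculatorStrategy and returns the nodes visited (simulation only).\n",
    "fun simulate(replies: List<Reply>, tokenUsage: List<Int>, threshold: Int = 1000): List<String> {\n",
    "    val visited = mutableListOf(\"nodeStart\", \"callLLM\")\n",
    "    val nextReply = replies.iterator()\n",
    "    val nextTokens = tokenUsage.iterator()\n",
    "    var reply = nextReply.next()\n",
    "    // onMultipleToolCalls: keep executing tools until the LLM answers with plain text.\n",
    "    while (reply.text == null) {\n",
    "        visited += \"executeTools${reply.toolCalls}\"\n",
    "        // onCondition: compress only when the simulated latest token usage exceeds the threshold.\n",
    "        if (nextTokens.next() > threshold) visited += \"compressHistory\"\n",
    "        visited += \"sendToolResults\"\n",
    "        reply = nextReply.next()\n",
    "    }\n",
    "    // onAssistantMessage: a plain answer goes to nodeFinish.\n",
    "    visited += \"nodeFinish\"\n",
    "    return visited\n",
    "}\n",
    "\n",
    "val path = simulate(\n",
    "    replies = listOf(\n",
    "        Reply(toolCalls = listOf(\"plus\", \"plus\", \"minus\")),\n",
    "        Reply(toolCalls = listOf(\"multiply\")),\n",
    "        Reply(toolCalls = listOf(\"divide\")),\n",
    "        Reply(text = \"About -33.33\")\n",
    "    ),\n",
    "    tokenUsage = listOf(400, 700, 1200) // made-up numbers to exercise both branches\n",
    ")\n",
    "println(path.joinToString(\" -> \"))\n",
    "```\n",
    "\n",
    "In the real strategy the LLM decides how to batch the calls; the simulation only replays a fixed script to show which node each reply leads to."
   ]
  },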
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Agent Configuration\n",
    "\n",
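    "A minimal, tool-forward prompt works well. A low temperature (for example, `LLMParams(temperature = 0.0)`) makes the arithmetic more consistent across runs, although LLM output is still not guaranteed to be deterministic."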
   ]
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-18T19:59:12.806070Z",
     "start_time": "2025-08-18T19:59:12.761564Z"
    }
   },
   "cell_type": "code",
   "source": [
    "val agentConfig = AIAgentConfig(\n",
    "    prompt = prompt(\"calculator\", LLMParams(temperature = 0.0)) {\n",
    "        system(\"You are a calculator. Always use the provided tools for arithmetic.\")\n",
    "    },\n",
    "    model = OpenAIModels.Chat.GPT4o,\n",
    "    maxAgentIterations = 50\n",
    ")"
   ],
   "outputs": [],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-18T19:59:13.410273Z",
     "start_time": "2025-08-18T19:59:13.296675Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import ai.koog.agents.features.eventHandler.feature.handleEvents\n",
    "\n",
    "val agent = AIAgent(\n",
    "    promptExecutor = executor,\n",
    "    strategy = CalculatorStrategy.strategy,\n",
    "    agentConfig = agentConfig,\n",
    "    toolRegistry = toolRegistry\n",
    ") {\n",
    "    handleEvents {\n",
    "        onToolCallStarting { e ->\n",
    "            println(\"Tool called: ${e.tool.name}, args=${e.toolArgs}\")\n",
    "        }\n",
    "        onAgentExecutionFailed { e ->\n",
    "            println(\"Agent error: ${e.throwable.message}\")\n",
    "        }\n",
    "        onAgentCompleted { e ->\n",
    "            println(\"Final result: ${e.result}\")\n",
    "        }\n",
    "    }\n",
    "}"
   ],
   "outputs": [],
   "execution_count": 6
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Try It\n",
    "\n",
    "The agent should split the expression into tool calls, running the independent ones (the two additions and the subtraction) in parallel, and then report the result."
   ]
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-18T19:59:21.790882Z",
     "start_time": "2025-08-18T19:59:14.577289Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import kotlinx.coroutines.runBlocking\n",
    "\n",
    "runBlocking {\n",
    "    agent.run(\"(10 + 20) * (5 + 5) / (2 - 11)\")\n",
    "}\n",
    "// Expected final value ≈ -33.333..."
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tool called: plus, args=VarArgs(args={parameter #1 a of fun Line_4_jupyter.CalculatorTools.plus(kotlin.Double, kotlin.Double): kotlin.String=10.0, parameter #2 b of fun Line_4_jupyter.CalculatorTools.plus(kotlin.Double, kotlin.Double): kotlin.String=20.0})\n",
      "Tool called: plus, args=VarArgs(args={parameter #1 a of fun Line_4_jupyter.CalculatorTools.plus(kotlin.Double, kotlin.Double): kotlin.String=5.0, parameter #2 b of fun Line_4_jupyter.CalculatorTools.plus(kotlin.Double, kotlin.Double): kotlin.String=5.0})\n",
      "Tool called: minus, args=VarArgs(args={parameter #1 a of fun Line_4_jupyter.CalculatorTools.minus(kotlin.Double, kotlin.Double): kotlin.String=2.0, parameter #2 b of fun Line_4_jupyter.CalculatorTools.minus(kotlin.Double, kotlin.Double): kotlin.String=11.0})\n",
      "Tool called: multiply, args=VarArgs(args={parameter #1 a of fun Line_4_jupyter.CalculatorTools.multiply(kotlin.Double, kotlin.Double): kotlin.String=30.0, parameter #2 b of fun Line_4_jupyter.CalculatorTools.multiply(kotlin.Double, kotlin.Double): kotlin.String=10.0})\n",
      "Tool called: divide, args=VarArgs(args={parameter #1 a of fun Line_4_jupyter.CalculatorTools.divide(kotlin.Double, kotlin.Double): kotlin.String=1.0, parameter #2 b of fun Line_4_jupyter.CalculatorTools.divide(kotlin.Double, kotlin.Double): kotlin.String=-9.0})\n",
      "Tool called: divide, args=VarArgs(args={parameter #1 a of fun Line_4_jupyter.CalculatorTools.divide(kotlin.Double, kotlin.Double): kotlin.String=300.0, parameter #2 b of fun Line_4_jupyter.CalculatorTools.divide(kotlin.Double, kotlin.Double): kotlin.String=-9.0})\n",
      "Final result: The result of the expression \\((10 + 20) * (5 + 5) / (2 - 11)\\) is approximately \\(-33.33\\).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "The result of the expression \\((10 + 20) * (5 + 5) / (2 - 11)\\) is approximately \\(-33.33\\)."
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 7
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Try Forcing Parallel Calls\n",
    "\n",
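    "Now ask the model to call all the tools at once.\n",
    "This is a useful stress test, but it can produce a wrong answer: the multiplication needs the results of both additions, and the division needs the product and the difference, so those calls cannot be correct until the earlier results exist. In the run below, the model passed 30 (instead of 300) to `divide` and reported about -3.33 instead of the correct -33.33."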
   ]
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-18T19:59:25.670570Z",
     "start_time": "2025-08-18T19:59:21.797036Z"
    }
   },
   "cell_type": "code",
   "source": [
    "runBlocking {\n",
    "    agent.run(\"Use tools to calculate (10 + 20) * (5 + 5) / (2 - 11). Please call all the tools at once.\")\n",
    "}"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tool called: plus, args=VarArgs(args={parameter #1 a of fun Line_4_jupyter.CalculatorTools.plus(kotlin.Double, kotlin.Double): kotlin.String=10.0, parameter #2 b of fun Line_4_jupyter.CalculatorTools.plus(kotlin.Double, kotlin.Double): kotlin.String=20.0})\n",
      "Tool called: plus, args=VarArgs(args={parameter #1 a of fun Line_4_jupyter.CalculatorTools.plus(kotlin.Double, kotlin.Double): kotlin.String=5.0, parameter #2 b of fun Line_4_jupyter.CalculatorTools.plus(kotlin.Double, kotlin.Double): kotlin.String=5.0})\n",
      "Tool called: minus, args=VarArgs(args={parameter #1 a of fun Line_4_jupyter.CalculatorTools.minus(kotlin.Double, kotlin.Double): kotlin.String=2.0, parameter #2 b of fun Line_4_jupyter.CalculatorTools.minus(kotlin.Double, kotlin.Double): kotlin.String=11.0})\n",
      "Tool called: multiply, args=VarArgs(args={parameter #1 a of fun Line_4_jupyter.CalculatorTools.multiply(kotlin.Double, kotlin.Double): kotlin.String=30.0, parameter #2 b of fun Line_4_jupyter.CalculatorTools.multiply(kotlin.Double, kotlin.Double): kotlin.String=10.0})\n",
      "Tool called: divide, args=VarArgs(args={parameter #1 a of fun Line_4_jupyter.CalculatorTools.divide(kotlin.Double, kotlin.Double): kotlin.String=30.0, parameter #2 b of fun Line_4_jupyter.CalculatorTools.divide(kotlin.Double, kotlin.Double): kotlin.String=-9.0})\n",
      "Final result: The result of \\((10 + 20) * (5 + 5) / (2 - 11)\\) is approximately \\(-3.33\\).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "The result of \\((10 + 20) * (5 + 5) / (2 - 11)\\) is approximately \\(-3.33\\)."
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 8
  },
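  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Why can forcing every call into one batch go wrong? The calls form a dependency graph: a call can only be issued once its inputs are known. The sketch below uses a hypothetical `Expr` tree (not part of Koog) to group the calls for `(10 + 20) * (5 + 5) / (2 - 11)` into rounds, where each round only uses results from earlier rounds.\n",
    "\n",
    "```kotlin\n",
    "// Hypothetical expression tree: a literal (op == null) or a binary tool call.\n",
    "data class Expr(val op: String? = null, val left: Expr? = null, val right: Expr? = null, val value: Double = 0.0)\n",
    "\n",
    "fun num(v: Double) = Expr(value = v)\n",
    "fun call(op: String, l: Expr, r: Expr) = Expr(op, l, r)\n",
    "\n",
    "// A call's round is one more than the latest round among its operands; literals are round 0.\n",
    "fun roundOf(e: Expr): Int =\n",
    "    if (e.op == null) 0 else 1 + maxOf(roundOf(e.left!!), roundOf(e.right!!))\n",
    "\n",
    "// Evaluates the tree and records each call under the earliest round in which it could run.\n",
    "fun evaluate(e: Expr, rounds: MutableMap<Int, MutableList<String>>): Double {\n",
    "    if (e.op == null) return e.value\n",
    "    val a = evaluate(e.left!!, rounds)\n",
    "    val b = evaluate(e.right!!, rounds)\n",
    "    rounds.getOrPut(roundOf(e)) { mutableListOf() } += \"${e.op}($a, $b)\"\n",
    "    return when (e.op) {\n",
    "        \"plus\" -> a + b\n",
    "        \"minus\" -> a - b\n",
    "        \"multiply\" -> a * b\n",
    "        \"divide\" -> a / b\n",
    "        else -> error(\"Unknown op ${e.op}\")\n",
    "    }\n",
    "}\n",
    "\n",
    "// (10 + 20) * (5 + 5) / (2 - 11)\n",
    "val expr = call(\n",
    "    \"divide\",\n",
    "    call(\"multiply\", call(\"plus\", num(10.0), num(20.0)), call(\"plus\", num(5.0), num(5.0))),\n",
    "    call(\"minus\", num(2.0), num(11.0))\n",
    ")\n",
    "\n",
    "val rounds = sortedMapOf<Int, MutableList<String>>()\n",
    "val result = evaluate(expr, rounds)\n",
    "rounds.forEach { (round, calls) -> println(\"Round $round: $calls\") }\n",
    "println(\"Result: $result\")\n",
    "```\n",
    "\n",
    "The expression needs three rounds, so a single batch can only be correct if the model fills in the intermediate values on its own, which is where the run above went wrong."
   ]
  },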
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Running with Ollama\n",
    "\n",
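    "Swap the executor and model if you prefer local inference. This assumes a local Ollama server is running and the Llama 3.2 model has been pulled.\n",
    "This agent has no event handler, so tool calls are not logged. Smaller local models may also behave differently: in the run below, the model delivered its answer through the `SayToUser` tool (printed as `Agent says: ...`) and returned a generic closing message as the final result."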
   ]
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-18T19:59:31.787067Z",
     "start_time": "2025-08-18T19:59:25.678910Z"
    }
   },
   "cell_type": "code",
   "source": [
    "val ollamaExecutor: PromptExecutor = simpleOllamaAIExecutor()\n",
    "\n",
    "val ollamaAgentConfig = AIAgentConfig(\n",
    "    prompt = prompt(\"calculator\", LLMParams(temperature = 0.0)) {\n",
    "        system(\"You are a calculator. Always use the provided tools for arithmetic.\")\n",
    "    },\n",
    "    model = OllamaModels.Meta.LLAMA_3_2,\n",
    "    maxAgentIterations = 50\n",
    ")\n",
    "\n",
    "\n",
    "val ollamaAgent = AIAgent(\n",
    "    promptExecutor = ollamaExecutor,\n",
    "    strategy = CalculatorStrategy.strategy,\n",
    "    agentConfig = ollamaAgentConfig,\n",
    "    toolRegistry = toolRegistry\n",
    ")\n",
    "\n",
    "runBlocking {\n",
    "    ollamaAgent.run(\"(10 + 20) * (5 + 5) / (2 - 11)\")\n",
    "}"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agent says: The result of the expression (10 + 20) * (5 + 5) / (2 - 11) is approximately -33.33.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "If you have any more questions or need further assistance, feel free to ask!"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 9
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Kotlin",
   "language": "kotlin",
   "name": "kotlin"
  },
  "language_info": {
   "name": "kotlin",
   "version": "2.2.20-Beta2",
   "mimetype": "text/x-kotlin",
   "file_extension": ".kt",
   "pygments_lexer": "kotlin",
   "codemirror_mode": "text/x-kotlin",
   "nbconvert_exporter": ""
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
